# -*- coding: utf-8 -*-

import torch
import argparse
from collections import OrderedDict

def change_model(args, prefix="student."):
    """Extract one sub-network's weights from a distillation checkpoint.

    Loads the checkpoint at ``args.nkd_path`` and keeps only the entries of
    its ``"state_dict"`` whose key starts with ``prefix``, with that prefix
    stripped from each kept key. It then saves the checkpoint to
    ``args.output_path``; all other top-level entries are left unchanged.

    Args:
        args: Namespace-like object with ``nkd_path`` (input checkpoint path)
            and ``output_path`` (destination path) attributes.
        prefix: Key prefix selecting the parameters to keep. Defaults to
            ``"student."``.

    Raises:
        KeyError: If the loaded checkpoint has no ``"state_dict"`` entry.
        ValueError: If no ``"state_dict"`` key starts with ``prefix``; no
            output file is written in that case.
    """
    # map_location="cpu" lets a checkpoint containing CUDA tensors be converted
    # on a machine without a GPU, and avoids GPU allocations for what is only a
    # file-to-file rewrite.
    # NOTE: torch.load unpickles its input; only convert trusted checkpoints.
    nkd_model = torch.load(args.nkd_path, map_location="cpu")
    cut = len(prefix)
    state_dict = OrderedDict(
        (name[cut:], value)
        for name, value in nkd_model["state_dict"].items()
        if name.startswith(prefix)
    )
    if not state_dict:
        # Saving an empty state_dict would silently produce an unusable file.
        raise ValueError(
            f"no state_dict keys start with {prefix!r} in {args.nkd_path}"
        )
    nkd_model["state_dict"] = state_dict
    torch.save(nkd_model, args.output_path)

           
if __name__ == '__main__':
    # Command-line entry point: parse the input/output checkpoint paths and
    # run change_model on them.
    parser = argparse.ArgumentParser(description='Transfer CKPT')
    parser.add_argument(
        '--nkd_path',
        type=str,
        default='work_dirs/nkd_res34_distill_res18_img/latest.pth',
        # The value is a file path; the previous metavar 'N' read as a count
        # in --help output.
        metavar='PATH',
        help='nkd_model path',
    )
    parser.add_argument(
        '--output_path',
        type=str,
        default='res18_new.pth',
        metavar='PATH',
        help='output path',
    )
    args = parser.parse_args()
    change_model(args)
